import torch
from torch.optim import AdamW
from avalanche.evaluation.metrics.accuracy import Accuracy
from tqdm import tqdm
from timm.models import create_model
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.loss import LabelSmoothingCrossEntropy
from argparse import ArgumentParser
from utils import *

import psutil


from models import (
    vision_transformer_adapter,
    vision_transformer_adaptformer,
    swin_transformer_adapter,
    swin_transformer_adaptformer,
)
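
# Importing these modules presumably registers the custom adapter/AdaptFormer
# ViT and Swin variants with timm's model registry, so that create_model()
# below can resolve names such as "vit_base_patch16_224_in21k_adapter".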

from logger import create_logger


def train(
    config, model, criterion, dl, test_dl, opt, scheduler, logger, epoch
):
    """Fine-tune the trainable (adapter/head) parameters for `epoch` epochs,
    evaluating on `test_dl` every 10 epochs and tracking the best accuracy."""
    model = model.cuda()

    for ep in tqdm(range(epoch)):
        model.train()  # re-enable train mode after the periodic eval pass
        for batch in dl:
            x, y = batch[0].cuda(), batch[1].cuda()
            out = model(x)

            loss = criterion(out, y)
            opt.zero_grad()
            loss.backward()
            opt.step()

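        # timm schedulers are stepped once per epoch with the epoch index,
        # not once per optimizer step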
        if scheduler is not None:
            scheduler.step(ep)

        ram_used = psutil.virtual_memory().used / (1024.0 * 1024.0)  # host RAM in MB
        memory_used = torch.cuda.max_memory_allocated() / (
            1024.0 * 1024.0
        )  # peak GPU memory in MB

        # evaluate every 10 epochs and track the best accuracy seen so far
        if ep % 10 == 9:
            acc = test(model, test_dl)
            if acc > config["best_acc"]:
                config["best_acc"] = acc
                # save('vit_sct', config['task'], config['name'], model, acc, ep)
            logger.info(
                f"{ep} {acc} RAM: {ram_used:.1f}MB GPU: {memory_used:.1f}MB"
            )
    model = model.cpu()
    return model


@torch.no_grad()
def test(model, dl):
    """Evaluate top-1 accuracy over the given dataloader."""
    model.eval()
    model = model.cuda()
    acc = Accuracy()
    for batch in dl:
        x, y = batch[0].cuda(), batch[1].cuda()
        out = model(x)
        acc.update(out.argmax(dim=1).view(-1), y)

    return acc.result()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--wd", type=float, default=1e-4)
    parser.add_argument("--eval", type=str, default="True")
    parser.add_argument("--dpr", type=float, default=0.1)
    parser.add_argument("--topN", type=int, default=None)
    parser.add_argument(
        "--model",
        type=str,
        default="vit_base_patch16_224_in21k_adapter",
    )
    parser.add_argument(
        "--model_ckp",
        type=str,
        default="./released_models/ViT-B_16.npz",
    )
    parser.add_argument(
        "--model_type", type=str, default="vit_adapter"
    )
    parser.add_argument("--task", type=str, default="vtab")
    parser.add_argument("--dataset", type=str, default="cifar")
    parser.add_argument("--tuning_mode", type=str, default="adapter")

    args = parser.parse_args()
    print(args)

    set_seed(args.seed)
    config = get_config("model_adapter", args.task, args.dataset)

    topN = args.topN if args.topN is not None else config["topN"]

    exp_base_path = "./output/%s/%s/%s" % (
        args.model_type,
        args.task,
        config["name"].replace("adapter", args.tuning_mode)
        + "_dim_%d" % (topN),
    )
    mkdirss(exp_base_path)
    logger = create_logger(output_dir=exp_base_path, name="training")

    logger.info(args)
    logger.info(config)

    ## prepare training data
    evalflag = args.eval == "True"

    if args.task == "vtab":
        from vtab import *

        basedir = "data/vtab-1k"
    elif args.task == "fgvc":
        from fgvc import *

        # the fgvc module is expected to supply get_data (and, presumably,
        # its own data paths) via the star import; basedir is only set for vtab

    if "train_aug" in config.keys():
        train_aug = config["train_aug"]
    else:
        train_aug = False

    train_dl, test_dl = get_data(
        basedir,
        args.dataset,
        logger,
        evaluate=evalflag,
        train_aug=train_aug,
        batch_size=config["batch_size"],
    )

    if "swin" in args.model:
        model = create_model(
            args.model,
            pretrained=False,
            drop_path_rate=args.dpr,
            tuning_mode=args.tuning_mode,
            topN=topN,
        )
        model.load_state_dict(
            torch.load(args.model_ckp)["model"], strict=False
        )  # strict=False: the checkpoint does not include the adapter modules
    else:
        model = create_model(
            args.model,
            checkpoint_path=args.model_ckp,
            drop_path_rate=args.dpr,
            tuning_mode=args.tuning_mode,
            topN=topN,
        )

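    # replace the pretrained classification head with a freshly initialized
    # one sized to the downstream task's class count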
    model.reset_classifier(config["class_num"])

    logger.info(str(model))

    config["best_acc"] = 0.
    config["task"] = args.task

    # collect only the adapter/AdaptFormer and classifier-head parameters for
    # optimization; freeze everything else in the backbone
    trainable = []
    for n, p in model.named_parameters():
        if "adapter" in n or "adaptmlp" in n or "head" in n:
            trainable.append(p)
            logger.info(str(n))
        else:
            p.requires_grad = False

    opt = AdamW(trainable, lr=args.lr, weight_decay=args.wd)

    if "cycle_decay" in config.keys():
        cycle_decay = config["cycle_decay"]
    else:
        cycle_decay = 0.1

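    # cosine decay across all training epochs, with a linear warmup from
    # warmup_lr_init up to the base learning rate over the first warmup epochs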
    scheduler = CosineLRScheduler(
        opt,
        t_initial=config["epochs"],
        warmup_t=config["warmup_epochs"],
        lr_min=1e-5,
        warmup_lr_init=1e-6,
        cycle_decay=cycle_decay,
    )

    n_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    logger.info(f"number of extra params: {n_parameters}")

    if config["labelsmoothing"] > 0.0:
        # label smoothing
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config["labelsmoothing"]
        )
        logger.info("label smoothing")
    else:
        criterion = torch.nn.CrossEntropyLoss()
        logger.info("CrossEntropyLoss")

    model = train(
        config,
        model,
        criterion,
        train_dl,
        test_dl,
        opt,
        scheduler,
        logger,
        config["epochs"],
    )
    print(config["best_acc"])

    logger.info("end")
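
# Example invocation (script name and paths assumed, mirroring the argparse
# defaults above):
#   python train.py --task vtab --dataset cifar \
#       --model vit_base_patch16_224_in21k_adapter \
#       --model_ckp ./released_models/ViT-B_16.npz --tuning_mode adapter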
